#!/usr/bin/env python3

"""
Test model generator for ReduceMax with boolean tensors.

For boolean tensors, ReduceMax is equivalent to logical OR: each reduced
element is True if any of the values it covers is True.
"""

import numpy as np
import onnx
from onnx import helper, TensorProto


def _make_reduce_max(output, axes, keepdims, axes_as_input):
    """Build one ReduceMax node that reads 'input'; return (node, initializers).

    Args:
        output: Name of the node's output tensor.
        axes: List of axes to reduce, or None to reduce over every dimension.
        keepdims: 1 to keep reduced dimensions with size 1, 0 to drop them.
        axes_as_input: If true (opset >= 18 layout), ``axes`` is passed as an
            int64 initializer wired to the node's second input; otherwise it
            is set as the legacy ``axes`` attribute.

    Returns:
        A ``(NodeProto, list[TensorProto])`` pair; the list holds the axes
        initializer when one was created and is empty otherwise.
    """
    inputs = ['input']
    attrs = {'keepdims': keepdims}
    initializers = []
    if axes is not None:
        if axes_as_input:
            axes_name = f'{output}_axes'
            initializers.append(
                helper.make_tensor(axes_name, TensorProto.INT64, [len(axes)], list(axes))
            )
            inputs.append(axes_name)
        else:
            attrs['axes'] = list(axes)
    node = helper.make_node('ReduceMax', inputs=inputs, outputs=[output], **attrs)
    return node, initializers


def create_model(opset_version=16):
    """Create an ONNX model with four ReduceMax operations on a boolean tensor.

    The single input 'input' is a bool tensor of shape [2, 3, 4]. The outputs are:
      - output1: reduce all dimensions, keepdims=0 -> scalar bool
      - output2: reduce all dimensions, keepdims=1 -> shape [1, 1, 1]
      - output3: reduce axis 2, keepdims=0         -> shape [2, 3]
      - output4: reduce axes [0, 2], keepdims=1    -> shape [1, 3, 1]

    Args:
        opset_version: Default-domain opset to declare. For versions >= 18 the
            axes are emitted as an input tensor, matching ReduceMax-18's
            signature. For earlier versions they are emitted as an attribute.

    Returns:
        The ModelProto, after it has passed ``onnx.checker.check_model``.

    NOTE(review): ReduceMax lists tensor(bool) in its type constraint only from
    ReduceMax-20, so with any opset_version below 20 this graph uses a type the
    operator's schema does not declare, and strict runtimes may refuse it. The
    default stays at 16 so the generated file is unchanged. Presumably the
    importer under test expects that, but confirm before raising it. Pass
    opset_version=20 to get a spec-conformant model.
    """
    input_tensor = helper.make_tensor_value_info('input', TensorProto.BOOL, [2, 3, 4])

    # ReduceMax-18 turned `axes` from an attribute into an optional second input.
    axes_as_input = opset_version >= 18

    # (output name, axes or None for "all", keepdims, declared output shape)
    reductions = [
        ('output1', None, 0, []),
        ('output2', None, 1, [1, 1, 1]),
        ('output3', [2], 0, [2, 3]),
        ('output4', [0, 2], 1, [1, 3, 1]),
    ]

    nodes = []
    initializers = []
    outputs = []
    for name, axes, keepdims, shape in reductions:
        node, node_initializers = _make_reduce_max(name, axes, keepdims, axes_as_input)
        nodes.append(node)
        initializers.extend(node_initializers)
        outputs.append(helper.make_tensor_value_info(name, TensorProto.BOOL, shape))

    graph = helper.make_graph(
        nodes=nodes,
        name='reduce_max_bool_model',
        inputs=[input_tensor],
        outputs=outputs,
        initializer=initializers,
    )

    model = helper.make_model(
        graph,
        producer_name='burn-import-test',
        opset_imports=[helper.make_opsetid("", opset_version)],
    )

    onnx.checker.check_model(model)

    return model


def _reference_outputs(input_data):
    """Return the expected results of the four ReduceMax nodes, in graph order.

    Uses ``np.any`` (logical OR) rather than ``np.max`` so the comparison also
    exercises the claim that ReduceMax on booleans is a logical OR.
    """
    return [
        np.any(input_data),                              # output1: all axes, no keepdims
        np.any(input_data, keepdims=True),               # output2: all axes, keepdims
        np.any(input_data, axis=2, keepdims=False),      # output3: axis 2, no keepdims
        np.any(input_data, axis=(0, 2), keepdims=True),  # output4: axes [0, 2], keepdims
    ]


def _print_evaluator_outputs(outputs):
    """Print the four ReferenceEvaluator outputs with their shapes."""
    print("ONNX Model Outputs (using ReferenceEvaluator):")
    print()

    # Output 1: Reduce all -> scalar (may come back as a 0-d array or NumPy scalar)
    print("Output 1 (reduce all, no keepdims):")
    print(f"  Shape: {outputs[0].shape if hasattr(outputs[0], 'shape') else 'scalar'}")
    print(f"  Value: {outputs[0]}")
    print()

    # Output 2: Reduce all with keepdims
    print("Output 2 (reduce all, keepdims=1):")
    print(f"  Shape: {outputs[1].shape}")
    print(f"  Value: {outputs[1]}")
    print()

    # Output 3: Reduce along axis 2
    print("Output 3 (reduce axis 2, no keepdims):")
    print(f"  Shape: {outputs[2].shape}")
    print("  Value:")
    print(outputs[2])
    print()

    # Output 4: Reduce along axes [0, 2] with keepdims
    print("Output 4 (reduce axes [0, 2], keepdims=1):")
    print(f"  Shape: {outputs[3].shape}")
    print("  Value:")
    print(outputs[3])
    print()


def _check_outputs(actual, expected):
    """Raise AssertionError if ``actual`` differs from ``expected``.

    Compares the number of outputs, then each output's dtype (must be bool),
    shape and values.
    """
    if len(actual) != len(expected):
        raise AssertionError(f"expected {len(expected)} outputs, got {len(actual)}")
    for index, (got, want) in enumerate(zip(actual, expected), start=1):
        got = np.asarray(got)
        want = np.asarray(want)
        if got.dtype != np.bool_:
            raise AssertionError(f"Output {index}: expected dtype bool, got {got.dtype}")
        if got.shape != want.shape:
            raise AssertionError(
                f"Output {index}: expected shape {want.shape}, got {got.shape}"
            )
        if not np.array_equal(got, want):
            raise AssertionError(
                f"Output {index}: expected {want.tolist()}, got {got.tolist()}"
            )


def generate_test_data(model):
    """Print the test input and the model's outputs, verified against NumPy.

    When ``onnx.reference`` is importable, ``model`` is run with
    ``ReferenceEvaluator``. Each output's dtype, shape and values are then
    checked against a logical-OR reference computed with NumPy. Otherwise
    the NumPy reference values are printed instead.

    Args:
        model: The ModelProto built by ``create_model``. It must have a bool input
            named 'input' and produce four outputs.

    Returns:
        The list of four outputs that were printed: the evaluator outputs, or
        the NumPy reference in the fallback path.

    Raises:
        AssertionError: If an evaluator output disagrees with the reference.
    """
    # Input data: [2, 3, 4] boolean tensor mixing all-False, all-True and
    # mixed rows so every reduction sees both outcomes.
    input_data = np.array([
        [
            [False, False, False, False],  # All False
            [True, True, True, True],      # All True
            [False, True, False, True],    # Mixed
        ],
        [
            [False, False, False, False],  # All False
            [False, True, False, False],   # Mostly False
            [True, False, True, False],    # Mixed
        ],
    ], dtype=bool)
    expected = _reference_outputs(input_data)

    print("=" * 80)
    print("Test data for reduce_max_bool (ReduceMax on boolean = logical OR):")
    print("=" * 80)
    print()

    print("Input shape:", input_data.shape)
    print("Input data:")
    print(input_data)
    print()

    # Only the import may legitimately fail; errors while running the model
    # must surface rather than being mistaken for "onnx.reference missing".
    try:
        from onnx.reference import ReferenceEvaluator
    except ImportError:
        print("onnx.reference not available, falling back to NumPy verification")
        print()

        print("NumPy Outputs:")
        for index, value in enumerate(expected, start=1):
            print(f"Output {index}: {value}")
        print()
        outputs = expected
    else:
        outputs = ReferenceEvaluator(model).run(None, {"input": input_data})
        _print_evaluator_outputs(outputs)
        # Print first so mismatching values are visible before we fail.
        _check_outputs(outputs, expected)
        print("All outputs match the NumPy logical-OR reference.")
        print()

    print("=" * 80)
    print("Explanation:")
    print("- For booleans, ReduceMax is equivalent to logical OR")
    print("- Output 1: OR of all 24 elements = True (there are True values)")
    print("- Output 3: OR along last axis for each [2, 3] position")
    print("  - [0,0,:] = All False = False")
    print("  - [0,1,:] = All True = True")
    print("  - [0,2,:] = False OR True OR False OR True = True")
    print("  - etc.")
    print("=" * 80)

    return outputs


if __name__ == '__main__':
    # Build the model once; the same object is written to disk and evaluated.
    model = create_model()

    onnx_path = 'reduce_max_bool.onnx'
    onnx.save(model, onnx_path)
    print(f"✓ Saved {onnx_path}\n")

    # Print the input and the per-output results.
    generate_test_data(model)
